# model_lognormal
#
# FUNCTIONS for baby model with NIPT
# version with recommendation for invasive (version 2 in write-up)
#

residvar <- function(did_invasive, pi) {
  # Residual variance of the OLS regression of did_invasive on pi.
  #
  # Args:
  #   did_invasive: numeric response vector.
  #   pi: numeric regressor, same length as did_invasive (shadows base::pi
  #       inside this function only).
  #
  # Returns: var() (n - 1 denominator) of the fitted model's residuals.
  fit <- lm(did_invasive ~ pi)
  # Exact [[ ]] extraction: the former `reg$resid` only worked through
  # partial matching of "resid" against the "residuals" list element.
  var(fit[["residuals"]])
}


apply_meanshift <- function(formula, data, beta) {
  # Linear mean shift X %*% beta, where X = model.matrix(formula, data).
  #
  # Args:
  #   formula: formula handed to model.matrix().
  #   data:    data frame holding the formula's variables.
  #   beta:    coefficients, one per column of the design matrix.
  #
  # Returns: a one-column matrix with one row per design-matrix row.
  # Errors:  when length(beta) differs from the number of design columns;
  #          the column names and beta are printed first to aid debugging.
  # NOTE(review): model.matrix may drop rows with missing values depending on
  # na.action — confirm callers expect one row per row of `data`.
  design <- model.matrix(formula, data)
  n_cols <- ncol(design)
  n_coefs <- length(beta)

  conformable <- n_cols == n_coefs
  if (!conformable) {
    print(colnames(design))
    print("beta:")
    print(beta)
    msg <- paste0(
      "Number of variables implied by formula (", n_cols,
      ") does not match the amount given (", n_coefs, ")."
    )
    stop(msg)
  }

  design %*% beta
}


# Function to calculate actual moments in the data + associated weights
#
# Returns three moment families, stacked in this order:
#   1. wave 3, by (policy_regime, bin_number): mean of did_nipt over rows with
#      non-missing did_nipt;
#   2. wave 3, same cells as family 1, among did_nipt == 0: mean of
#      did_invasive over rows with non-missing did_invasive;
#   3. wave 2, by bin_number: mean of did_invasive over rows with non-missing
#      did_invasive.
# Weights are the number of observations behind each moment. A family-1 cell
# with nobody eligible for family 2 gets moment 0 and weight 0, so it carries
# no weight in a weighted sum of squares (temporary workaround).
#
# Args:
#   data: observation-level data with columns wave, did_nipt, did_invasive,
#         bin_number and policy_regime.
#
# Returns: unnamed list(moments, wgts) of equal-length numeric vectors.
moments_wgts <- function(data){

  # Family 1: NIPT take-up, with cell counts.
  nipt <- data %>%
    filter(wave == 3, !is.na(did_nipt)) %>%
    group_by(bin_number, policy_regime) %>%
    summarise(n_nipt = n(), mean_nipt = mean(did_nipt), .groups = "drop") %>%
    arrange(policy_regime, bin_number)

  # Family 2: invasive testing among those who did not do NIPT.
  inv_n <- data %>%
    filter(wave == 3, did_nipt == 0, !is.na(did_invasive)) %>%
    group_by(bin_number, policy_regime) %>%
    summarise(n_invN = n(), mean_invN = mean(did_invasive), .groups = "drop")

  # Align family 2 on family 1's cells (left_join keeps nipt's row order);
  # cells with no eligible observations get a zero moment with zero weight.
  total <- nipt %>%
    left_join(inv_n, by = c("policy_regime", "bin_number")) %>%
    mutate(
      n_invN = coalesce(as.numeric(n_invN), 0),
      mean_invN = coalesce(as.numeric(mean_invN), 0)
    )

  # Family 3: wave-2 invasive testing by bin.
  total_w2 <- data %>%
    filter(wave == 2, !is.na(did_invasive)) %>%
    group_by(bin_number) %>%
    summarise(n_inv_w2 = n(), mean_inv_w2 = mean(did_invasive), .groups = "drop") %>%
    arrange(bin_number)

  # Pull columns directly (not via as.matrix) so the result stays numeric
  # regardless of the type of bin_number.
  moments <- as.numeric(c(total$mean_nipt, total$mean_invN, total_w2$mean_inv_w2))
  wgts <- as.numeric(c(total$n_nipt, total$n_invN, total_w2$n_inv_w2))

  return(list(moments, wgts))
}


# Function that draws a_i and c_i for each observation in data given parameters
#
# For every row of `data` (stacked J times), draws a pair (a_i, c_i) from a
# bivariate lognormal via rlnorm.rplus(). The draws are negated, so both
# values are non-positive.
#
# Each row's log-scale location is (par[1], par[2]) plus covariate shifts from
# apply_meanshift(f_a / f_c, data). The log-scale covariance is built from
# par[3] (sd of a), par[4] (sd of c) and par[5] (correlation), then passed
# through make.positive.definite().
#
# Rows sharing the same location are drawn together: one rlnorm.rplus() call
# per distinct location. Drawing starts right after set.seed(3491), so calls
# with the same group sizes reuse the same random numbers.
#
# Args:
#   par: a_mean, c_mean, a_sd, c_sd, rho, psi (not used here), then the f_a
#        coefficients (one per model.matrix column), then the f_c
#        coefficients.
#   data: observation-level data frame.
#   f_a, f_c: formulas for the location shifters of a and c.
#   J: number of stacked copies of `data` to draw for (values below 1 are
#      treated as 1).
#   x_vars: names of extra columns carried through to the output.
#   unique_flag: only 1 is supported; any other value is reported as an error
#      and NA is returned.
#
# Returns: data frame of the selected columns plus a_i and c_i, in input row
#   order, or NA if anything inside the tryCatch fails (the error is printed).
#
# NOTE(review): rlnorm.rplus and make.positive.definite are assumed to come
# from the compositions and corpcor packages attached elsewhere — confirm.
draw_a_c <- function(par, data, f_a, f_c, J=1, x_vars, unique_flag=1){
  a_mean <- par[1]
  c_mean <- par[2]
  a_sd <- par[3]
  c_sd <- par[4]
  rho <- par[5]

  # Columns needed downstream to simulate decisions. Kept as a data frame
  # (not a matrix) so column names and types pass through unchanged.
  keep_cols <- c("pregnancy", "fetus_risk", "p_i", "wave", "oop_i",
                 "bin_number", "policy_id", "policy_regime", x_vars)
  one_copy <- data %>%
    dplyr::select(all_of(keep_cols)) %>%
    as.data.frame()

  # Stack J copies of the data, one set of draws per copy.
  n_copies <- max(1, J)
  sim <- one_copy[rep(seq_len(nrow(one_copy)), times = n_copies), , drop = FALSE]
  rownames(sim) <- NULL

  # Log-scale covariance matrix; make.positive.definite() repairs parameter
  # combinations proposed by the optimizer that are not positive definite.
  sigma <- matrix(c(a_sd^2, a_sd*c_sd*rho, a_sd*c_sd*rho, c_sd^2), 2, 2)
  sigma <- make.positive.definite(sigma, tol=1e-8)

  tryCatch({
    if (unique_flag != 1) {
      stop("draw_a_c: only unique_flag = 1 is supported")
    }

    # Split the tail of `par` into f_a and f_c coefficients. The column count
    # is taken from `data`, the same frame apply_meanshift() builds from.
    n_a <- ncol(model.matrix(f_a, data))
    a_coefs <- par[7:(7 + n_a - 1)]
    c_coefs <- par[(7 + n_a):length(par)]

    a_loc <- a_mean + as.vector(apply_meanshift(f_a, data, a_coefs))
    c_loc <- c_mean + as.vector(apply_meanshift(f_c, data, c_coefs))

    # Attach each row's location (repeated for each stacked copy) and record
    # the input order. Then label rows by distinct location (`group`) and give
    # each a within-group index (`n`) used to match draws back to rows.
    sim <- sim %>%
      mutate(
        a_mean_shifted = rep(a_loc, times = n_copies),
        c_mean_shifted = rep(c_loc, times = n_copies),
        original_order = row_number()
      ) %>%
      group_by(a_mean_shifted, c_mean_shifted) %>%
      mutate(group = cur_group_id(), n = row_number()) %>%
      ungroup()

    # One row per distinct location, with the number of draws it needs.
    loc_groups <- sim %>%
      dplyr::count(group, a_mean_shifted, c_mean_shifted, name = "size") %>%
      arrange(group)

    set.seed(3491)
    # Draw every group from its own location. Previously only the first row's
    # location was used and all draws were tagged group 1, which dropped every
    # other group in the join below.
    bvn <- dplyr::bind_rows(lapply(seq_len(nrow(loc_groups)), function(g) {
      draws <- rlnorm.rplus(
        n = loc_groups$size[g],
        meanlog = c(loc_groups$a_mean_shifted[g], loc_groups$c_mean_shifted[g]),
        varlog = sigma
      )
      as.data.frame(draws) %>%
        mutate(group = loc_groups$group[g], n = row_number()) %>%
        mutate(V1 = -V1) %>%
        mutate(V2 = -V2)
    }))
    print("mean")
    print(mean(bvn$V1))
    print(mean(bvn$V2))
    print("sd")
    print(sd(bvn$V1))
    print(sd(bvn$V2))
    print("realized correlation:")
    print(cor(bvn$V1, bvn$V2))

    # Match draws to rows, restore input order, and drop helper columns.
    out <- sim %>%
      dplyr::inner_join(bvn, by = c("group", "n")) %>%
      rename(a_i = V1, c_i = V2) %>%
      arrange(original_order) %>%
      dplyr::select(-a_mean_shifted, -c_mean_shifted, -group, -n, -original_order)

    as.data.frame(out)
  },

  error = function(x){
    print(x)
    return(NA)
  }
  )
}


# Function that calculates decisions for each observation i
#
# Draws a_i and c_i via draw_a_c(), then solves a two-stage decision tree by
# comparing expected utilities:
#   stage 1: take NIPT (cfDNA test) or not;
#   stage 2: take an invasive test or not, given the NIPT outcome.
#
# Args:
#   par: parameter vector (layout as in draw_a_c); par[6] is psi, the exponent
#        applied to p_i when the recommendation is triggered.
#   data, f_a, f_c, x_vars, unique_flag: passed through to draw_a_c().
#
# Returns: data frame with p_i, wave, oop_i, bin_number, policy_id,
#   policy_regime, a_i, c_i, pred_nipt, pred_invasive and the x_vars columns;
#   NA if draw_a_c() failed. pred_invasive is 0/1 for predicted non-NIPT
#   takers and a probability-weighted value for NIPT takers.
#
# NOTE(review): pfp, pfn and pa_belief are not arguments; they are looked up
# in the enclosing (presumably global) environment and must exist.
model_decisions <- function(par, data, f_a, f_c, x_vars, unique_flag=1){
  psi <- par[6]

  # draw a_i and c_i for each observation in data given parameters
  draw <- draw_a_c(par=par, data=data, f_a=f_a, f_c=f_c, unique_flag=unique_flag, x_vars=x_vars)
  # draw_a_c() signals failure by returning NA instead of a table.
  if (!(is.data.frame(draw) || is.matrix(draw))) {
    return(NA)
  }

  final <- as.data.frame(draw) %>%
    # Recommendation to do invasive based on the KUB result (p_i >= 1/200)
    # replaces p_i with p_i^psi everywhere on the perceived decision tree.
    mutate(
      rec = ifelse(p_i >= 1/200, 1, 0),
      p = ifelse(rec == 1, p_i^psi, p_i)
    ) %>%
    # true_pr_pos: true probability of a positive cfDNA test (uses p_i)
    # pr_pos:      believed probability of a positive cfDNA test (uses p)
    # pr_c_pos / pr_c_neg: posterior after a positive / negative cfDNA test
    mutate(
      true_pr_pos = (1 - p_i) * pfp + p_i * (1 - pfn),
      pr_pos = (1 - p) * pfp + p * (1 - pfn),
      pr_c_pos = (1 - pfn) * p / pr_pos,
      pr_c_neg = pfn * p / (1 - pr_pos)
    ) %>%
    # Stage-2 expected utilities, named ui2_<invasive Y/N><NIPT result:
    # 0 = no NIPT, P = positive, N = negative>. pa_belief enters every
    # invasive (Y) branch; the original code marks it as the miscarriage-rate
    # term. NIPT branches pay the out-of-pocket cost oop_i.
    mutate(
      ui2_Y0 = pa_belief * a_i + (1 - pa_belief) * p * pmax(a_i, c_i),
      ui2_N0 = pmax(a_i, p * c_i),
      ui2_YP = pa_belief * a_i + (1 - pa_belief) * pr_c_pos * pmax(a_i, c_i) - oop_i,
      ui2_NP = pmax(a_i, pr_c_pos * c_i) - oop_i,
      ui2_YN = pa_belief * a_i + (1 - pa_belief) * pr_c_neg * pmax(a_i, c_i) - oop_i,
      ui2_NN = pmax(a_i, pr_c_neg * c_i) - oop_i
    ) %>%
    # Stage-2 choices (ties go to invasive).
    mutate(
      invasive_P = ifelse(ui2_YP >= ui2_NP, 1, 0),
      invasive_N = ifelse(ui2_YN >= ui2_NN, 1, 0),
      invasive_0 = ifelse(ui2_Y0 >= ui2_N0, 1, 0)
    ) %>%
    # Stage-1 values: NIPT averages the best stage-2 value over the believed
    # test outcome.
    mutate(
      ui1_Y = pr_pos * pmax(ui2_YP, ui2_NP) + (1 - pr_pos) * pmax(ui2_YN, ui2_NN),
      ui1_N = pmax(ui2_Y0, ui2_N0)
    ) %>%
    # NIPT is predicted only when strictly preferred; values equal to 5
    # decimals count as indifference, which resolves to no NIPT.
    mutate(
      nipt_indiff = ifelse(round(ui1_Y, 5) == round(ui1_N, 5), 1, 0),
      pred_nipt = ifelse(nipt_indiff == 0, ifelse(ui1_Y > ui1_N, 1, 0), 0)
    ) %>%
    # With NIPT, predicted invasive testing is averaged over the TRUE
    # positive-test probability.
    mutate(
      pred_invasive = ifelse(pred_nipt == 0, invasive_0,
                             true_pr_pos * invasive_P + (1 - true_pr_pos) * invasive_N)
    ) %>%
    dplyr::select(p_i, wave, oop_i, bin_number, policy_id, policy_regime, a_i, c_i, pred_nipt, pred_invasive, all_of(x_vars))

  return(final)
}


# Function that uses the previous one to calculate model moments and the objective function
#
# SMM objective: weighted sum of squared differences between model-implied and
# data moments. Model moments are stacked in this order:
#   wave-3 NIPT share by (policy_regime, bin_number);
#   wave-3 invasive share among predicted non-NIPT takers, same cells
#     (NA where no such takers; those terms are dropped via na.rm);
#   wave-2 invasive share by bin_number.
#
# Args:
#   par: parameter vector (see draw_a_c / model_decisions).
#   data: observation-level data.
#   data_moments, weights: vectors in the same layout as the model moments
#     (presumably elements 1 and 2 of moments_wgts(data)).
#   f_a, f_c, x_vars, unique_flag: passed through to model_decisions().
#   info: FALSE -> append (par, objective) to the CSV logs and return the
#         objective; TRUE -> return a diagnostic list without logging.
#
# Returns: the objective (Inf if the simulation failed) or, with info = TRUE,
#   a list with model/data moments, errors, moment names and decisions.
# Errors: if model moments, data moments and weights differ in length.
#
# NOTE(review): TEMP and RESULTS are free variables (output directories) that
# must be defined in the calling environment.
obj <- function(par, data, data_moments, weights, f_a, f_c, x_vars, unique_flag=1, info=FALSE){

  # calculate model-predicted decisions
  d <- model_decisions(par=par, data=data, f_a=f_a, f_c=f_c, x_vars=x_vars, unique_flag=unique_flag)
  # model_decisions() returns NA when drawing failed; treat as infeasible.
  if (!(is.data.frame(d) || is.matrix(d))) {
    return(Inf)
  }
  decisions <- as.data.frame(d)

  nipt <- decisions %>%
    filter(wave == 3) %>%
    group_by(bin_number, policy_regime) %>%
    summarise(mean_nipt = mean(pred_nipt), .groups="drop") %>%
    arrange(policy_regime, bin_number)

  invN <- decisions %>%
    filter(wave == 3) %>%
    filter(pred_nipt==0) %>%
    group_by(bin_number, policy_regime) %>%
    summarise(mean_invN = mean(pred_invasive), .groups="drop") %>%
    arrange(policy_regime, bin_number)

  # Grouped by bin only, matching the wave-2 data moments in moments_wgts()
  # and the "mean_inv_w2_bin_<bin>" moment names below.
  inv_w2 <- decisions %>%
    filter(wave == 2) %>%
    group_by(bin_number) %>%
    summarise(mean_inv_w2 = mean(pred_invasive), .groups="drop") %>%
    arrange(bin_number)

  total <- left_join(nipt, invN, by=c("bin_number", "policy_regime"))

  model_moments <- c(total$mean_nipt, total$mean_invN, inv_w2$mean_inv_w2)

  # Moments are matched by position, so any length mismatch would silently
  # recycle and produce a meaningless objective.
  if (length(model_moments) != length(data_moments) ||
      length(weights) != length(data_moments)) {
    stop(sprintf(
      "obj: %d model moments, %d data moments and %d weights do not line up",
      length(model_moments), length(data_moments), length(weights)
    ))
  }

  # weighted sum of squared errors
  error <- model_moments - data_moments
  wgtsum <- sum(error*weights*error, na.rm=TRUE)

  if (info == FALSE) {
    output <- t(append(par, wgtsum))
    write.table(output, file=paste0(TEMP, "/estimate_w_xs/estimates_obj.csv"), sep = ",", na=".", append = TRUE, col.names = FALSE, row.names = FALSE)
    write.table(output, file=paste0(RESULTS, "/estimate_w_xs/estimates_obj.csv"), sep = ",", na=".", append = TRUE, col.names = FALSE, row.names = FALSE)
    return(wgtsum)
  }

  list(
    model=model_moments,
    data=data_moments,
    error=error,
    objective=wgtsum,
    total=total,
    bins=total$bin_number,
    moment_names=c(
      paste0("mean_nipt_bin_",total$bin_number, "_regime_", total$policy_regime),
      paste0("mean_invN_bin_",total$bin_number, "_regime_", total$policy_regime),
      paste0("mean_inv_w2_bin_",inv_w2$bin_number)
    ),
    a_i = decisions$a_i,
    c_i = decisions$c_i,
    decisions = decisions
  )
}